{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adversarial Robustness Toolbox for Provenance-Based Defenses\n",
    "\n",
    "In this notebook we will learn how to use ART to defend against adversarial attacks in IoT settings.\n",
    "\n",
    "When data is collected from multiple sources, we can use **provenance features** to track the origin of that data. Using those features, we can defend models against malicious attacks. We will also show how to use the Reject on Negative Impact (RONI) defense method within ART."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "import os, sys\n",
    "from os.path import abspath\n",
    "\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "from art.attacks.poisoning.poisoning_attack_svm import PoisoningAttackSVM\n",
    "from art.estimators.classification.scikitlearn import ScikitlearnSVC\n",
    "from art.defences.detector.poison import ProvenanceDefense, RONIDefense\n",
    "from art.utils import load_mnist\n",
    "from sklearn.svm import SVC\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "np.random.seed(301)\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_training = 40\n",
    "num_poison = 5\n",
    "num_valid = 40 # the number of validation examples available to the attacker\n",
    "num_trusted = 25 # the number of trusted examples available to the defender\n",
    "num_devices = 4 # the last device is the one inserting poison\n",
    "kernel = 'linear' # available kernels are 'rbf', 'poly' and 'linear'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and transform MNIST data\n",
    "\n",
    "In this example we train a classifier that distinguishes the digit 4 from the digit 0. The clean training data is split among the first `num_devices - 1` devices, and the poisoned training data is then added to the last device. The quantity of data and the model kernel are set by the hyperparameters above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test), min_, max_ = load_mnist()\n",
    "y_train = np.argmax(y_train, axis=1)\n",
    "y_test = np.argmax(y_test, axis=1)\n",
    "zero_or_four = np.logical_or(y_train == 4, y_train == 0)\n",
    "x_train = x_train[zero_or_four]\n",
    "y_train = y_train[zero_or_four]\n",
    "tr_labels = np.zeros((y_train.shape[0], 2))\n",
    "tr_labels[y_train == 0] = np.array([1, 0])\n",
    "tr_labels[y_train == 4] = np.array([0, 1])\n",
    "y_train = tr_labels\n",
    "\n",
    "\n",
    "zero_or_four = np.logical_or(y_test == 4, y_test == 0)\n",
    "x_test = x_test[zero_or_four]\n",
    "y_test = y_test[zero_or_four]\n",
    "te_labels = np.zeros((y_test.shape[0], 2))\n",
    "te_labels[y_test == 0] = np.array([1, 0])\n",
    "te_labels[y_test == 4] = np.array([0, 1])\n",
    "y_test = te_labels\n",
    "\n",
    "n_samples_train = x_train.shape[0]\n",
    "n_features_train = x_train.shape[1] * x_train.shape[2] * x_train.shape[3]\n",
    "n_samples_test = x_test.shape[0]\n",
    "n_features_test = x_test.shape[1] * x_test.shape[2] * x_test.shape[3]\n",
    "\n",
    "x_train = x_train.reshape(n_samples_train, n_features_train)\n",
    "x_test = x_test.reshape(n_samples_test, n_features_test)\n",
    "x_train = x_train[:num_training]\n",
    "y_train = y_train[:num_training]\n",
    "\n",
    "trusted_data = x_test[:num_trusted]\n",
    "trusted_labels = y_test[:num_trusted]\n",
    "x_test = x_test[num_trusted:]\n",
    "y_test = y_test[num_trusted:]\n",
    "valid_data = x_test[:num_valid]\n",
    "valid_labels = y_test[:num_valid]\n",
    "x_test = x_test[num_valid:]\n",
    "y_test = y_test[num_valid:]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add provenance data and poison samples\n",
    "\n",
    "*Note:* In real applications, provenance data would be loaded along with the training data. Here it is generated for demonstration purposes.\n",
    "\n",
    "This code takes longer to run the more poison samples you allow. Each sample is generated independently by iteratively maximizing the generalization loss of the original SVM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# assign random provenance features to the original training points\n",
    "clean_prov = np.random.randint(num_devices - 1, size=x_train.shape[0])\n",
    "p_train = np.eye(num_devices)[clean_prov]\n",
    "\n",
    "no_defense = ScikitlearnSVC(model=SVC(kernel=kernel), clip_values=(min_, max_))\n",
    "no_defense.fit(x=x_train, y=y_train)\n",
    "# initialize a predetermined number of poison points from randomly chosen support vectors\n",
    "poison_points = np.random.randint(no_defense._model.support_vectors_.shape[0], size=num_poison)\n",
    "all_poison_init = np.copy(no_defense._model.support_vectors_[poison_points])\n",
    "# flip the predicted one-hot label so that each starting point carries the opposite class\n",
    "poison_labels = np.array([1,1]) - no_defense.predict(all_poison_init)\n",
    "\n",
    "\n",
    "svm_attack = PoisoningAttackSVM(classifier=no_defense, x_train=x_train, y_train=y_train,\n",
    "                                step=0.1, eps=1.0, x_val=valid_data, y_val=valid_labels, max_iter=200)\n",
    "\n",
    "poisoned_data, _ = svm_attack.poison(all_poison_init, y=poison_labels)\n",
    "\n",
    "# Stack the poison onto the data and add the provenance of the bad actor\n",
    "all_data = np.vstack([x_train, poisoned_data])\n",
    "all_labels = np.vstack([y_train, poison_labels])\n",
    "poison_prov = np.zeros((num_poison, num_devices))\n",
    "poison_prov[:,num_devices - 1] = 1\n",
    "all_p = np.vstack([p_train, poison_prov])"
   ]
  },
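  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two encodings used in the cell above can be sketched in isolation. This is a minimal illustration with made-up values; the `toy_*` names are assumptions introduced here and are not part of the experiment. Provenance is stored as one one-hot row per sample, built by indexing an identity matrix with device IDs, and a two-class one-hot label is flipped by subtracting it from `[1, 1]`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "toy_num_devices = 4\n",
    "# hypothetical device ID for each of four samples\n",
    "toy_device_ids = np.array([0, 2, 1, 3])\n",
    "# row i of the identity matrix is the one-hot code for device i\n",
    "toy_provenance = np.eye(toy_num_devices)[toy_device_ids]\n",
    "\n",
    "# one-hot labels: [1, 0] is the digit 0, [0, 1] is the digit 4\n",
    "toy_labels = np.array([[1, 0], [0, 1]])\n",
    "# subtracting from [1, 1] swaps the two columns, flipping each label\n",
    "toy_flipped = np.array([1, 1]) - toy_labels\n",
    "\n",
    "print(toy_provenance)\n",
    "print(toy_flipped)\n",
    "```\n",
    "\n",
    "In the experiment, the flip pairs each starting point with the opposite of the class the model predicts for it."
   ]
  },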
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Poison\n",
    "\n",
    "By setting `idx` to any value from 0 to `num_poison - 1`, you can visualize each poison sample. Notice how the samples attempt to add features from the other class to confuse the classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAEbCAYAAADaoTEAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASO0lEQVR4nO3df5BdZX3H8fdnN5us+QEkICGE2AhSlbESnC0yQts4mSowIDDVtrFFap2GTsGBGSowtB2wYy1lAGvHFicByq8IZUQUa9pCIy1SHcaAEZKGFmoRQtZEiCFLQshm99s/7olsw93n3N3749zN83nNZPbmfO+595uT3c+eH899jiICM8tXT9UNmFm1HAJmmXMImGXOIWCWOYeAWeYcAmaZcwgchCRdJenmqvtohKQvS/qzqvvImTxOoHtJeg6YD4wAu4A1wKcj4tUq+6pHUgC7gQBeAf4B+ExEjLTwPW4DNkfEn7bqNc17AlPB2RExG3gf8MtAN/8AnFj0ugz4OPAHFfdjDXAITBER8SLwT8B7ACQdLekBSdslPSvp5z9wkq6RdFfxuF/SXZJelrRD0vclzW/wNe6VdIekIUkbJQ002OvTwHfG9PpuSf9WvP9GSR8Z8z63Sfpc8XippM2SLpO0TdKgpE8WtRXA7wCXS3pV0jeb2Z72BofAFCFpEXAm8INi0d3AZuBo4KPA5yUtq7PqBcChwCLgcOAPgdcafI2PAPcAhwEPAF9qsNcTgF8BfiCpD/gm8CBwJPBpYLWkd46z+lFFvwuBTwF/K2luRKwEVgPXRcTsiDi7kV6snEOg+31d0g7gUeDfqf2gLgJOA66IiD0RsR64GTi/zvrD1H743xERIxHxeETsbPA1Ho2INcVx/Z3AiSW9PiHpZ9R+6G8G/h44BZgNXBsReyPi28A/AsvHeY1h4M8jYjgi1gCvAuMFhrXAtKobsFLnRsS/jl0g6Whge0QMjVn8Y6De7vqd1PYC7pF0GHAX8CfUfvuXvcZPxjzeDfRLmhYR+8bp9X0R8WydXl+IiNED3mfhOK/x8gGvv5taiFibeE9gatoCzJM0Z8yytwEvHvjE4jfqZyPiBOADwFnAJybyGi3odZGksd9rk30fX8pqA4fAFBQRLwDfBf6yOPH3XmrHz6sPfK6kD0r6JUm9wE5qu9sjE3mNJj1G7fLm5ZL6JC0FzqZ2rmGitgLHtrA3wyEwlS0HFlP7TXs/cHVEPFTneUcBX6UWAJuonVe4a4KvMWkRsZfaCcYzgJeAvwM+UVxBmKhbgBOKqwxfb2GbWfNgIbPMeU/ALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMVRICkk6X9F/FDLdXVtFDiqTnJD0lab2kdV3Qz63F7LsbxiybJ+khSc8UX+d2WX/XSHqx2IbrJZ1ZYX+LJD0saVMx2/ElxfKu2IaJ/jqyDTs+n0Axw81/A79Obabb7wPLI+I/O9pIQnHTj4GIeKnqXgAk/Sq1CTfviIj903hfR22OwGuLIJ0bEVd0UX/XAK9GxPVV9DSWpAXAgoh4ophO7XHgXOD36IJtmOjvN+nANqxiT+Bk4NmI+FEx68w9wDkV9DFlRMQjwPYDFp8D3F48vp3aN00lxumva0TEYEQ8UTweojbD0kK6ZBsm+uuIKkJgIfDCmL9vpoP/4AYF8KCkx4ubXnSj+RExCLVvImpz+nebiyU9WRwuVHa4MpakxcBJ1OY+7LpteEB/0IFtWEUIqM6ybpvj7NSIeB+1efEuKnZ3bWJuAo4DlgCDwA3VtgOSZgP3AZdGxM6q+zlQnf46sg2rCIHN1ObB3+8YahNddo2I2FJ83UZtAs6Tq+2orq3FseT+Y8ptFffz/0TE1uJmJ6PAKirehsWdkO4DVkfE14rFXbMN6/XXqW1YRQh8Hzhe0tsl
TQd+m9otrrqCpFn75+KXNAv4ELAhvVYlHqB2izGKr9+osJc32f/DVTiPCrehJFGbqXhTRNw4ptQV23C8/jq1DSuZbbi41PHXQC9wa0T8RcebGIekY6n99ofaHZq+UnV/ku4GlgJHUJt7/2rg68C91G7k8TzwsYio5OTcOP0tpbYbG8BzwIX7j78r6O80ajdIfQrYfyekq6gdd1e+DRP9LacD29BTjptlziMGzTLnEDDLnEPALHMOAbPMOQTMMldpCHTxkFzA/TWrm/vr5t6gs/1VvSfQ1f8RuL9mdXN/3dwbdLC/qkPAzCrW1GAhSacDX6Q28u/miLg29fzpmhH9zPr534d5nT5mTPr92839Naeb++vm3qD1/e1hF3vj9Xof3pt8CExmcpBDNC/er2WTej8zm7zHYi07Y3vdEGjmcMCTg5gdBJoJgakwOYiZlZjWxLoNTQ5SXOpYAdDPzCbezszaoZk9gYYmB4mIlRExEBED3XwixixXzYRAV08OYmaNmfThQETsk3Qx8C+8MTnIxpZ1ZmYd0cw5ASJiDbCmRb2YWQU8YtAscw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDV1a3Kzidjyxx9I1h/49HXJ+h07Tk7Wv/mlX0vWD1/1vWQ9V02FgKTngCFgBNgXEQOtaMrMOqcVewIfjIiXWvA6ZlYBnxMwy1yzIRDAg5Iel7SiFQ2ZWWc1ezhwakRskXQk8JCkpyPikbFPKMJhBUA/M5t8OzNrtab2BCJiS/F1G3A/8KbTtxGxMiIGImKgjxnNvJ2ZtcGkQ0DSLElz9j8GPgRsaFVjZtYZzRwOzAful7T/db4SEf/ckq4OUs988ZRk/Z2rdiTroxuebmU7Hbfik99K1n+875Bk/ZxDfpCsf2/525P1WJUsZ2vSIRARPwJObGEvZlYBXyI0y5xDwCxzDgGzzDkEzDLnEDDLnEPALHOeT2ACeo84PFnfsvydyfqtZ305WX/+w/OS9dXvOiZZ73bz+9LjIHaMpoeVvzwyO1l/4WeHJevH8GKynivvCZhlziFgljmHgFnmHAJmmXMImGXOIWCWOYeAWeY8TmACXn/v4mT94cuvT9Z3jI4m6/N6dyfrq5na4wQW96Unpf7pyJxkfdforGR9eK+/nSfDewJmmXMImGXOIWCWOYeAWeYcAmaZcwiYZc4hYJY5X1idgJ69I8n6rkiPAxhGyfr2kYP7Nm2Lp+1N1ocjPU5i12j6DlbT+tL/P1af9wTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzHicwAT37SsYBRHr9PdGbrO8tqU91M5T+ndOvfcn6UdNeSdZ7ekr+A6yu0j0BSbdK2iZpw5hl8yQ9JOmZ4uvc9rZpZu3SyOHAbcDpByy7ElgbEccDa4u/m9kUVBoCEfEIsP2AxecAtxePbwfObXFfZtYhkz0xOD8iBgGKr0e2riUz66S2nxiUtAJYAdDPwf0BGbOpaLJ7AlslLQAovm4b74kRsTIiBiJioI/0p8DMrPMmGwIPABcUjy8AvtGadsys00oPByTdDSwFjpC0GbgauBa4V9KngOeBj7WzyW6h19OfVy+bL2A40pk7NPqWZL3nxHcn66M/3JSsl1K6f6LkOnzJ+sNl8y3E9GS9l/T6M2ek5yuw+kpDICKWj1Na1uJezKwCHjZsljmHgFnmHAJmmXMImGXOIWCWOYeAWeY8n8AEDC49NFkfifR18j2R3txDo/3J+rNXpEdcHvvxZLn0Ov6O3z0lWR96W3r9d334mWT9ldH0OIORknEWvaTX7+1JjyNoehzEQcp7AmaZcwiYZc4hYJY5h4BZ5hwCZplzCJhlziFgljmPE5iAuWdsSdaH
SzK1bL6AoZF0/TMnPZis796Y/jz+4ukvJeun9T+arM9UX7L+ymj68/zbR9P3VRgp/Z2UHgcw53NzStYvkek4Au8JmGXOIWCWOYeAWeYcAmaZcwiYZc4hYJY5h4BZ5jxOYIzR05Yk63+0+P5kfWg0fZ3+5ZHZyfpwpK+jv7VnZ7J+6lv+J1mf35u+zt6n9O+E7SXjAF4pGQewq2Q+hTKzNJysT9vxWrI+UvLvo+S+CAcr7wmYZc4hYJY5h4BZ5hwCZplzCJhlziFgljmHgFnmPE5gjJ7/+GGy/tmnzkrWbzjxq8n6ntH05/Fn9+5J1vt70tfJ+5S+zl027/+eko/L74n0OIhdJfXdo+n7JszseT1Zf3rvUcn68LyZyXpP2TiAg3S+gDKlewKSbpW0TdKGMcuukfSipPXFnzPb26aZtUsjhwO3AafXWf6FiFhS/FnT2rbMrFNKQyAiHgG2d6AXM6tAMycGL5b0ZHG4MLdlHZlZR002BG4CjgOWAIPADeM9UdIKSeskrRsmfeLHzDpvUiEQEVsjYiQiRoFVwMmJ566MiIGIGOgjfXbYzDpvUiEgacGYv54HbBjvuWbW3UrHCUi6G1gKHCFpM3A1sFTSEiCA54AL29hj55RcJz7mNzYm61ef//vJ+nsuemrCLY01py89juCH096WrPdpJFn/yuplyfrMn6S3z4VXpudbOKx3d7LeXzJfwIbXjknWf1TyXfiO7+Q5DqBMaQhExPI6i29pQy9mVgEPGzbLnEPALHMOAbPMOQTMMucQMMucQ8Asc55PoIUOu/N7yfrmO9vdQfrz/GUW8t2m1v+rd52XrF//0duT9bL5BD566Lpkfe3hv5isW33eEzDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8x5nIC9QUrXy+ZbWLs3WT/yt4aS9emk7wuwJ9Lfrr3yfAGT4T0Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnMcJ5KRsHECTBk9N32FqT/Ql6ztH+5P1Q3rS913o8TiBSfGegFnmHAJmmXMImGXOIWCWOYeAWeYcAmaZcwiYZc7jBOwNJfMFlNFIur5tZE6y3q/hdJ10fWZfej4DjyKor3RPQNIiSQ9L2iRpo6RLiuXzJD0k6Zni69z2t2tmrdbI4cA+4LKIeDdwCnCRpBOAK4G1EXE8sLb4u5lNMaUhEBGDEfFE8XgI2AQsBM4B9t9X6nbg3HY1aWbtM6ETg5IWAycBjwHzI2IQakEBHNnq5sys/RoOAUmzgfuASyNi5wTWWyFpnaR1w6RvOGlmnddQCEjqoxYAqyPia8XirZIWFPUFwLZ660bEyogYiIiBPtKfMjOzzmvk6oCAW4BNEXHjmNIDwAXF4wuAb7S+PTNrt0bGCZwKnA88JWl9sewq4FrgXkmfAp4HPtaeFq1lmhwHUGbJGZuS9aGRtyTr06elBxoc3vNast7fuy9ZT6+dr9IQiIhHgfFmo1jW2nbMrNM8bNgscw4Bs8w5BMwy5xAwy5xDwCxzDgGzzHk+AWuZkw/732S9vyc9H8Bbe9Oj0Q/tSY8jmO5xApPiPQGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucxwlMJRrvE92FNs8XUOaoaa8k6yfMGEyv35seB/D5bUuT9c1fOj5Zn8PLyXquvCdgljmHgFnmHAJmmXMImGXOIWCWOYeAWeYcAmaZ8ziBqaTicQBlbrosfeuJb335b5L12T2zkvVvrz45WV9wz3eTdavPewJmmXMImGXOIWCWOYeAWeYcAmaZcwiYZc4hYJY5RQevPR+iefF++W7m4+ry+QJs6nos1rIzttf9BivdE5C0SNLDkjZJ2ijpkmL5NZJelLS++HNmqxs3s/ZrZMTgPuCyiHhC0hzgcUkPFbUvRMT17WvPzNqtNAQiYhAYLB4PSdoELGx3Y2bWGRM6MShpMXAS8Fix6GJJT0q6VdLccdZZIWmdpHXDvN5Us2bWeg2HgKTZ
wH3ApRGxE7gJOA5YQm1P4YZ660XEyogYiIiBPma0oGUza6WGQkBSH7UAWB0RXwOIiK0RMRIRo8AqIP0RLzPrSo1cHRBwC7ApIm4cs3zBmKedB2xofXtm1m6NXB04FTgfeErS+mLZVcBySUuAAJ4DLmxLhweTsnEAZhVo5OrAo0C97941rW/HzDrNw4bNMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5zvO9BJng/AupD3BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMdve+ApJ8CPx6z6AjgpY41MHHurznd3F839wat7+8XIuKt9QodDYE3vbm0LiIGKmughPtrTjf31829QWf78+GAWeYcAmaZqzoEVlb8/mXcX3O6ub9u7g062F+l5wTMrHpV7wmYWcUcAmaZcwiYZc4hYJY5h4BZ5v4PPqhQrabQ6K4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAEbCAYAAADaoTEAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASkElEQVR4nO3deZAc5X3G8e8jISRACCQLsLhtLIwdVyGoBcxRiVLCGFOhgFTAyBiLGEoU5lIKEhNCDCSBIpw+MCqLI4hwmPuq4owKh+CAbAEykhA2NggQyBJCERIGCR2//NG99miZ7dmdq2f3fT5VUzvbb0/3b3t3nn175p23FRGYWbqGlF2AmZXLIWCWOIeAWeIcAmaJcwiYJc4hYJY4h0CHknS+pBuavW4fthWSPteMbZVB0kRJi8uuYyBxCLSBpJMkzZP0oaTfS5ouaduix0TEpRFxSl+23591GyHpZ5LWSPqg4vZwq/drreUQaDFJ5wD/Dvw9sA3wZWA34ElJm/fymM3aV2G/nRERIytuR5ZdkDXGIdBCkkYBFwNnRsRjEbEuIhYBx5EFwTfz9S6SdI+kWyWtAk7Kl91asa1vSXpD0nuS/lnSIkmHVjz+1vz+7nmXfoqkNyUtl/RPFdvZX9KzklZKWiLp2t7CqJ8/63clPdcdYJJOk7RA0oj8+7vzXtD7kp6W9GcVj71Z0nWSHs17Fz+X9GlJ35f0f5JekbRPxfqLJP2jpJfz9v/o3k+VunaUdK+kdyW9LumsRn/WwcYh0FoHASOA+yoXRsQHwKPAVyoWHwXcA2wL3Fa5vqQvAtcBJwDjyHoUO9XY9yHA54FJwPckfSFfvgH4O2AscGDe/p1+/lzVXAF8DFwgaTxwKfDNiFiTtz8KjAe2B16gx89IFowX5HWtBZ7N1xtLdlyu7rH+CcBXgT2APfPHbkLSEOBh4Fdkx2sSME3SVxv5QQcbh0BrjQWWR8T6Km1L8vZuz0bEAxGxMSI+6rHu3wAPR8QzEfEx8D2g1oc+Lo6IjyLiV2RPgr0BIuL5iHguItbnvZKfAH/Rj5/ph3kvovv2r/l2NwLfAs4CHgIuj4gXux8UETdFxOqIWAtcBOwtaZuK7d6f17YGuB9YExG3RMQG4E5gHzZ1bUS8FRErgEuAyVVq3Q/YLiL+JSI+jojXgOuB4/vx8w56nXzuORgsB8ZK2qxKEIzL27u9VbCdHSvbI+JDSe/V2PfvK+5/CIwEkLQn2X/VLmBLsr+B52tsq9JZEVH1nYiIWCTpKeAI4MfdyyUNJXuiHgtsB2zMm8YC7+f3l1Zs6qMq34/ssbvK4/UG2THqaTdgR0krK5YNBf6nWv2pck+gtZ4l69r+deVCSVsBXwNmVSwu+s++BNi54vFbAJ+qs6bpwCvA+IgYBZwPqM5tbULSEWSnGLPITg+6fYPsdOdQslOZ3bsf0sDudqm4vyvwTpV13gJej4htK25bR8QRDex30HEItFBEvE/2wuCPJB0uaZik3YG7gcXAf/ZxU/cAR0o6KH8R72LqfwJtDawCPpC0F3BandvZhKSxwI3AKcCUvN7uJ9vWZGH4Hlnv49Im7PJ0STtLGkMWZHdWWecXwKr8RcstJA2V9CVJ+zVh/4OGQ6DFIuJysj/SK8mefLPJ/kNNys+P+7KNBcCZwE/JegWrgWVkT6z+OpfsP/NqsvPjak+eItf2GCfQfSoxA3gwIh6JiPeAk4EbJH0KuIWsy/428DLwXB1193Q78ATwWn77t54r5K8nHAlMAF4nO/26gaw3Yjl5UpGBR9JIYCVZl/71sutpN0mLgFMi4r/KrmUwcE9ggJB0pKQt89cTrgTmAYvKrcoGA4fAwHEU2Ytf75C93358uBtnTeDTAbPEuSdgljiHgFniHAJmiXMImCXOIWCWOIeAWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZokrJQTyCTZ+Lem3ks4ro4Yi
+Wy28yTNlTSnA+q5SdIySfMrlo2R9KSkV/OvozusvoskvZ0fw7kVE4yUUd8ukp6StDCfAfnsfHlHHMOC+tpyDNv+AaJ8vrnfkM20uxj4JTA5Il5uayEF8s+rd0XE8lrrtoOkPwc+AG6JiC/lyy4HVkTEZXmQjo6I73ZQfRcBH0TElWXUVEnSOGBcRLwgaWuyORWPBk6iA45hQX3H0YZjWEZPYH/gtxHxWj5z7k/JPiZrvYiIp4EVPRYfBczM788k+6MpRS/1dYyIWBIRL+T3VwMLyaYg74hjWFBfW5QRAjux6Uyxi2njD9xHATwh6XlJU8suphc7RMQSyP6IyObz7zRnSHopP10o7XSlUj7H4z5k07x13DHsUR+04RiWEQLVJsjstEkNDo6IfclmBD497+5a/0wnuzDIBLJ5Ea8qt5w/Tst2LzAtIlaVXU9PVepryzEsIwQWs+l00TtTfbro0kTEO/nXZWQXwti/3IqqWpqfS3afUy4ruZ5NRMTSiNiQX5Tkeko+hpKGkT3BbouI7itCdcwxrFZfu45hGSHwS2C8pM/k02cfT3bFmo4gaav8xZnu6wMcBswvflQpHiKb2pv864Ml1vIJ3U+u3DGUeAwliWw69IURUXk5s444hr3V165jWMr0YvlbHd8nuxrMTRFxSduL6IWkz5L994fs6jy3l12fpDuAiWRX7FkKXAg8ANxFduGNN4Fj80tydUp9E8m6sUE2Ieqp3effJdR3CNlVh+bxp6sfnU923l36MSyobzJtOIaeY9AscR4xaJY4h4BZ4hwCZolzCJglziFglrhSQ6CDh+QCrq9RnVxfJ9cG7a2v7J5AR/8icH2N6uT6Ork2aGN9ZYeAmZWsocFCkg4HfkA28u+GiLisaP3NNTxGsNUfv1/HWoYxvO79t5rra0wn19fJtUHz61vDH/g41lb78F79IVDP5CCjNCYO0KS69mdm9Zsds1gVK6qGQCOnA54cxGwQaCQEBsLkIGZWw2YNPLZPk4Pkb3VMBRjBlg3szsxaoZGeQJ8mB4mIGRHRFRFdnfxCjFmqGgmBjp4cxMz6pu7TgYhYL+kM4HH+NDnIgqZVZmZt0chrAkTEI8AjTarFzErgEYNmiXMImCXOIWCWOIeAWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZolzCJglziFgljiHgFniHAJmiXMImCXOIWCWOIeAWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZolzCJglziFgljiHgFniHAJmiXMImCXOIWCWOIeAWeIcAmaJcwiYJa6hS5Ob9cdrlx9Y2P7yCdcWtl+9Yq/C9juvO7Swfbvpzxa2p6qhEJC0CFgNbADWR0RXM4oys/ZpRk/gLyNieRO2Y2Yl8GsCZolrNAQCeELS85KmNqMgM2uvRk8HDo6IdyRtDzwp6ZWIeLpyhTwcpgKMYMsGd2dmzdZQTyAi3sm/LgPuB/avss6MiOiKiK5hDG9kd2bWAnWHgKStJG3dfR84DJjfrMLMrD0aOR3YAbhfUvd2bo+Ix5pS1SD16g8PKGzf4661he1DnpnbzHKaTsOLe3qPff2KwvaNNXqK08a8XNg+78SdCtvfnV7YnKy6QyAiXgP2bmItZlYCv0VoljiHgFniHAJmiXMImCXOIWCWOIeAWeI8n0A/DN1h+8L2t4//XGH7fUdeU9i+6q+K3ye/dM/iT2rH+vWF7WUbN3TzwvYNEQ1tf2Ooocenyj0Bs8Q5BMwS5xAwS5xDwCxxDgGzxDkEzBLnEDBLnMcJ9MNHE3YtbP/FP/ygxhZqZe664mYN7MwequL38RsdJ2D1Gdh/VWbWMIeAWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZonzOAFrGtUYBzCk1v8cbSxs9jiC1nBPwCxxDgGzxDkEzBLnEDBLnEPALHEOAbPEOQTMEudxAjZg1JqPwOpTsycg6SZJyyTNr1g2RtKTkl7Nv45ubZlm1ip9OR24GTi8x7LzgFkRMR6YlX9v
ZgNQzRCIiKeBFT0WHwXMzO/PBI5ucl1m1ib1vjC4Q0QsAci/Fl+kz8w6VstfGJQ0FZgKMIItW707M+unensCSyWNA8i/LuttxYiYERFdEdE1jOKr7ppZ+9UbAg8BU/L7U4AHm1OOmbVbzdMBSXcAE4GxkhYDFwKXAXdJOhl4Ezi2lUWmotbn5Vd+fd/C9m1ufa6Z5Qw4Q+T5BupRMwQiYnIvTZOaXIuZlcDDhs0S5xAwS5xDwCxxDgGzxDkEzBLnEDBLnOcT6Id3DhlW6v4/e9qvC9tXPjSqsH3DqlWF7Sv+9sDC9g92Lf48/z6HLSxs30jxdQWsHO4JmCXOIWCWOIeAWeIcAmaJcwiYJc4hYJY4h4BZ4jxOoB8OOHRBQ4+vNV9ALTfs9nhh+7MvblHYvi6Kf937Dv95YfvoISMK28seB7Doqs8Xtm/F7DZVMrC4J2CWOIeAWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZonzOIEKa47cv7D9wh2vrrGFcq+wdMiINQ1uYfPC1rLHAdQyfOX6sksYkNwTMEucQ8AscQ4Bs8Q5BMwS5xAwS5xDwCxxDgGzxHmcQIWRL75d2H7ma8cVtt8z/oFmljPo1JpPYaiKr2uw8OPicQprxhT/OY8sbE1XzZ6ApJskLZM0v2LZRZLeljQ3vx3R2jLNrFX6cjpwM3B4leXXRMSE/PZIc8sys3apGQIR8TSwog21mFkJGnlh8AxJL+WnC6ObVpGZtVW9ITAd2AOYACwBruptRUlTJc2RNGcda+vcnZm1Sl0hEBFLI2JDRGwErgd6/fhdRMyIiK6I6BpW8qfszOyT6goBSeMqvj0GmN/bumbW2WqOE5B0BzARGCtpMXAhMFHSBCCARcCpLayxbdYvLh4nsNk3Pl3Yvt9J0wrbL/n2Lf2uqT+GNPh5/3Pun1LYPvy94vfx7/vOFYXtOw8d1u+aKv1u3XaF7Ru/vbx4A3c3tPtBq2YIRMTkKotvbEEtZlYCDxs2S5xDwCxxDgGzxDkEzBLnEDBLnEPALHGKGp/xbqZRGhMHaFLb9mft9ZsbugrbFxx+XWF7rfkEajn5ja8Utr970MqGtj+QzY5ZrIoVVQ+wewJmiXMImCXOIWCWOIeAWeIcAmaJcwiYJc4hYJY4X3fAmmb7/y6eL2Do1xobB2Ct4Z6AWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZolzCJglzuMErGmWHbih7BKsDu4JmCXOIWCWOIeAWeIcAmaJcwiYJc4hYJY4h4BZ4jxOwJpn842FzRsavMZFo9clsOpq9gQk7SLpKUkLJS2QdHa+fIykJyW9mn8d3fpyzazZ+nI6sB44JyK+AHwZOF3SF4HzgFkRMR6YlX9vZgNMzRCIiCUR8UJ+fzWwENgJOAqYma82Ezi6VUWaWev064VBSbsD+wCzgR0iYglkQQFs3+zizKz1+hwCkkYC9wLTImJVPx43VdIcSXPWsbaeGs2shfoUApKGkQXAbRFxX754qaRxefs4YFm1x0bEjIjoioiuYQxvRs1m1kR9eXdAwI3Awoi4uqLpIWBKfn8K8GDzyzOzVuvLOIGDgROBeZLm5svOBy4D7pJ0MvAmcGxrSrROsdnOOxW2/2TizMJ260w1QyAingF6G6UxqbnlmFm7ediwWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZonzfALWZzFi88L2Q0b8oaHte76AcrgnYJY4h4BZ4hwCZolzCJglziFgljiHgFniHAJmifM4AeszbSi+rsDyjR8Xtu8wtLGZpS5Yun9h++9+tFdh+yiea2j/g5V7AmaJcwiYJc4hYJY4h4BZ4hwCZolzCJglziFgljiPE7A+W//6G4Xtk889t7D9Z9f8uKH9P37rgYXt4+7434a2nyr3BMwS5xAwS5xDwCxxDgGzxDkEzBLnEDBLnEPALHGKiLbtbJTGxAHy1czN2m12zGJVrKh6YYeaPQFJu0h6StJCSQsknZ0vv0jS25Lm5rcj
ml24mbVeX0YMrgfOiYgXJG0NPC/pybztmoi4snXlmVmr1QyBiFgCLMnvr5a0ENip1YWZWXv064VBSbsD+wCz80VnSHpJ0k2SRvfymKmS5kias461DRVrZs3X5xCQNBK4F5gWEauA6cAewASynsJV1R4XETMioisiuobR2ESTZtZ8fQoBScPIAuC2iLgPICKWRsSGiNgIXA8UTwVrZh2pL+8OCLgRWBgRV1csH1ex2jHA/OaXZ2at1pd3Bw4GTgTmSZqbLzsfmCxpAhDAIuDUllRoZi3Vl3cHngGqDTJ4pPnlmFm7ediwWeIcAmaJcwiYJc4hYJY4h4BZ4hwCZolzCJglziFgljiHgFniHAJmiXMImCXOIWCWOIeAWeIcAmaJa+t1ByS9C1Re5H4ssLxtBfSf62tMJ9fXybVB8+vbLSK2q9bQ1hD4xM6lORHRVVoBNbi+xnRyfZ1cG7S3Pp8OmCXOIWCWuLJDYEbJ+6/F9TWmk+vr5NqgjfWV+pqAmZWv7J6AmZXMIWCWOIeAWeIcAmaJcwiYJe7/ATQYWiRk8DY+AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "idx = 0\n",
    "plt.matshow(poisoned_data[idx].reshape(28, 28))\n",
    "plt.title(\"Poison Point\\n\")\n",
    "plt.matshow(all_poison_init[idx].reshape(28, 28))\n",
    "plt.title(\"Original Example\\n\")\n",
    "plt.clim(0,1)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that sometimes the poison appears to be a number 4 with features of the number zero in the background. It may appear as a shadowed zero \"watermarking\" the four. The aim of inserting poisonous samples like these into the training set is to shift the decision boundary so that actual 0s are also classified as 4s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train clean classifier and poisoned classifier\n",
    "perfect_defense = ScikitlearnSVC(model=SVC(kernel=kernel), clip_values=(min_, max_))\n",
    "perfect_defense.fit(x=x_train, y=y_train)\n",
    "no_defense.fit(x=all_data, y=all_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perfect defense accuracy (test set) 97.68%\n",
      "No defense accuracy (test set) 77.81%\n"
     ]
    }
   ],
   "source": [
    "perf_acc = np.average(np.all(perfect_defense.predict(x_test) == y_test, axis=1)) * 100\n",
    "no_acc = np.average(np.all(no_defense.predict(x_test) == y_test, axis=1)) * 100\n",
    "print(\"Perfect defense accuracy (test set) {0:.2f}%\".format(perf_acc))\n",
    "print(\"No defense accuracy (test set) {0:.2f}%\".format(no_acc))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply Defenses\n",
    "\n",
    "We will apply the following defenses to this poisoning attack:\n",
    "* **Perfect Defense**: All poison is detected and the model is trained on clean data.\n",
    "* **Provenance-Based Defense with Trusted Data**: Poison is detected using the provenance defense algorithm described below.\n",
    "* **Provenance-Based Defense without Trusted Data**: With no trusted validation data available, each data segment is checked for suspected poison using a random subset of the training points as the test set.\n",
    "* **RONI Defense w/ Calibration**: Poison is detected using the RONI defense method (see below).\n",
    "* **RONI Defense w/o Calibration**: Suspicious points are flagged using a fixed threshold value, epsilon.\n",
    "* **No Defense**: The model is trained on the poisoned data.\n",
    "\n",
    "### RONI Defense\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"../utils/data/images/roni.gif\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import HTML\n",
    "HTML('<img src=\"../utils/data/images/roni.gif\">')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [RONI (Reject on Negative Impact) defense method](https://www.usenix.org/legacy/event/leet08/tech/full_papers/nelson/nelson_html/#SECTION00051000000000000000) checks the empirical effect of each point on the performance of the classifier and removes suspicious points. Our provenance defense is similar, except that instead of checking each point we check each set of points with the same provenance feature. We compare RONI against both the provenance defense and the perfect defense."
   ]
  },
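  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-point check that RONI performs can be sketched as follows. This is a simplified illustration and not ART's `RONIDefense` implementation: a nearest-centroid classifier on one-dimensional toy data stands in for the SVM, and every function and variable name here is an assumption made for the sketch. A candidate point is kept only if adding it to a base training set lowers accuracy on trusted validation data by no more than `eps`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def fit_centroids(x, y):\n",
    "    # one centroid per class label (0 and 1)\n",
    "    return np.stack([x[y == c].mean(axis=0) for c in (0, 1)])\n",
    "\n",
    "def accuracy(centroids, x, y):\n",
    "    # predict the class whose centroid is closest to each sample\n",
    "    dists = np.linalg.norm(x[:, None, :] - centroids[None, :, :], axis=2)\n",
    "    return float(np.mean(dists.argmin(axis=1) == y))\n",
    "\n",
    "def roni_filter(x_base, y_base, x_cand, y_cand, x_val, y_val, eps=0.1):\n",
    "    base_acc = accuracy(fit_centroids(x_base, y_base), x_val, y_val)\n",
    "    keep = []\n",
    "    for xc, yc in zip(x_cand, y_cand):\n",
    "        # retrain with the single candidate added to the base set\n",
    "        x_aug = np.vstack([x_base, xc[None, :]])\n",
    "        y_aug = np.append(y_base, yc)\n",
    "        new_acc = accuracy(fit_centroids(x_aug, y_aug), x_val, y_val)\n",
    "        # reject the candidate if it costs more than eps in accuracy\n",
    "        keep.append(base_acc - new_acc <= eps)\n",
    "    return np.array(keep)\n",
    "\n",
    "# toy data: class 0 sits left of the origin, class 1 sits right of it\n",
    "toy_x_base = np.array([[-3.0], [-2.0], [-1.0], [1.0], [2.0], [3.0]])\n",
    "toy_y_base = np.array([0, 0, 0, 1, 1, 1])\n",
    "toy_x_val = np.array([[-2.5], [-1.5], [-0.5], [0.5], [1.5], [2.5]])\n",
    "toy_y_val = np.array([0, 0, 0, 1, 1, 1])\n",
    "\n",
    "# first candidate is an ordinary class-1 point, second is a far-off outlier labeled class 1\n",
    "toy_x_cand = np.array([[2.5], [-20.0]])\n",
    "toy_y_cand = np.array([1, 1])\n",
    "\n",
    "toy_keep = roni_filter(toy_x_base, toy_y_base, toy_x_cand, toy_y_cand, toy_x_val, toy_y_val)\n",
    "print(toy_keep)\n",
    "```\n",
    "\n",
    "The outlier drags the class-1 centroid across the class-0 data, so validation accuracy falls and the point is rejected, while the ordinary point leaves accuracy unchanged and is kept."
   ]
  },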
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "roni_defense = RONIDefense(no_defense, all_data, all_labels, trusted_data, trusted_labels)\n",
    "roni_defense.detect_poison()\n",
    "# same defense with calibration disabled\n",
    "roni_no_cal = RONIDefense(no_defense, all_data, all_labels, trusted_data, trusted_labels, calibrated=False)\n",
    "roni_no_cal.detect_poison()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Provenance Defense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"../utils/data/images/prov_defense.gif\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import HTML\n",
    "HTML('<img src=\"../utils/data/images/prov_defense.gif\">')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The provenance defense method checks the effect of removing segments of the data that may come from a bad actor intentionally poisoning the data. When a segment is found to be potentially poisonous, it is flagged as suspicious.\n",
    "\n",
    "In the trusted data version of the algorithm, the defender has some handpicked trusted data to test the performance of the model. In the version of the algorithm without trusted data, a random subset of training points from all segments is used as the test set."
   ]
  },
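  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trusted-data variant described above can be sketched as a segment-wise check. This is a simplified illustration and not ART's `ProvenanceDefense` implementation: a nearest-centroid classifier on one-dimensional toy data stands in for the SVM, each filtered model is compared against a single model trained on all data, and every function and variable name is an assumption made for the sketch. A device is flagged when removing its segment raises accuracy on the trusted set by more than `eps`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def fit_centroids(x, y):\n",
    "    # one centroid per class label (0 and 1)\n",
    "    return np.stack([x[y == c].mean(axis=0) for c in (0, 1)])\n",
    "\n",
    "def accuracy(centroids, x, y):\n",
    "    # predict the class whose centroid is closest to each sample\n",
    "    dists = np.linalg.norm(x[:, None, :] - centroids[None, :, :], axis=2)\n",
    "    return float(np.mean(dists.argmin(axis=1) == y))\n",
    "\n",
    "def flag_suspicious_devices(x, y, device_ids, x_trusted, y_trusted, eps=0.1):\n",
    "    full_acc = accuracy(fit_centroids(x, y), x_trusted, y_trusted)\n",
    "    flagged = []\n",
    "    for device in np.unique(device_ids):\n",
    "        # retrain without the segment contributed by this device\n",
    "        keep = device_ids != device\n",
    "        filtered_acc = accuracy(fit_centroids(x[keep], y[keep]), x_trusted, y_trusted)\n",
    "        # a large gain from dropping the segment marks it as suspicious\n",
    "        if filtered_acc - full_acc > eps:\n",
    "            flagged.append(int(device))\n",
    "    return flagged\n",
    "\n",
    "# devices 0 and 1 send clean data; device 2 sends far-off points labeled class 1\n",
    "toy_x = np.array([[-3.0], [-2.0], [2.0], [3.0], [-2.5], [-1.5], [1.5], [2.5], [-20.0], [-22.0]])\n",
    "toy_y = np.array([0, 0, 1, 1, 0, 0, 1, 1, 1, 1])\n",
    "toy_devices = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])\n",
    "toy_x_trusted = np.array([[-2.0], [-1.0], [1.0], [2.0]])\n",
    "toy_y_trusted = np.array([0, 0, 1, 1])\n",
    "\n",
    "toy_flagged = flag_suspicious_devices(toy_x, toy_y, toy_devices, toy_x_trusted, toy_y_trusted)\n",
    "print(toy_flagged)\n",
    "```\n",
    "\n",
    "Checking whole segments rather than single points needs one retraining per device instead of one per sample, which is the practical difference from RONI."
   ]
  },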
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "prov_defense_trust = ProvenanceDefense(no_defense, all_data, all_labels, all_p, \n",
    "                                       x_val=trusted_data, y_val=trusted_labels, eps=0.1)\n",
    "prov_defense_trust.detect_poison()\n",
    "prov_defense_no_trust = ProvenanceDefense(no_defense, all_data, all_labels, all_p, eps=0.1)\n",
    "prov_defense_no_trust.detect_poison()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Defenses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_is_clean = np.array([1 if i < num_training else 0 for i in range(len(all_data))])\n",
    "def evaluate_defense(defense, name):\n",
    "    print(\"\\nEvaluating results of {} defense...\".format(name))\n",
    "    # poison is the positive class: clean points kept are true negatives, poison points flagged are true positives\n",
    "    pc_tn = np.average(real_is_clean[:num_training] == defense.is_clean_lst[:num_training]) * 100\n",
    "    pc_tp = np.average(real_is_clean[num_training:] == defense.is_clean_lst[num_training:]) * 100\n",
    "    print(\"Percent of normal points correctly labeled (True Negative): {0:.2f}%\".format(pc_tn))\n",
    "    print(\"Percent of poison points correctly labeled (True Positive): {0:.2f}%\".format(pc_tp))\n",
    "    \n",
    "    # retrain on only the points the defense considers clean, then score on the held-out test set\n",
    "    classifier = ScikitlearnSVC(model=SVC(kernel=kernel), clip_values=(min_, max_))\n",
    "    mask = np.array(defense.is_clean_lst) == 1\n",
    "    classifier.fit(all_data[mask], all_labels[mask])\n",
    "    acc = np.average(np.all(classifier.predict(x_test) == y_test, axis=1)) * 100\n",
    "    print(\"Accuracy of classifier trained with {} filter on test set: {:.2f}%\".format(name, acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Evaluating results of RONI w/o Calibration defense...\n",
      "Percent of normal points correctly labeled (True Negative): 95.00%\n",
      "Percent of poison points correctly labeled (True Positive): 0.00%\n",
      "Accuracy of classifier trained with RONI w/o Calibration filter on test set: 82.92%\n",
      "\n",
      "Evaluating results of RONI w/ Calibration defense...\n",
      "Percent of normal points correctly labeled (True Negative): 100.00%\n",
      "Percent of poison points correctly labeled (True Positive): 20.00%\n",
      "Accuracy of classifier trained with RONI w/ Calibration filter on test set: 80.81%\n",
      "\n",
      "Evaluating results of Provenance Defense w/o Trusted Data defense...\n",
      "Percent of normal points correctly labeled (True Negative): 70.00%\n",
      "Percent of poison points correctly labeled (True Positive): 100.00%\n",
      "Accuracy of classifier trained with Provenance Defense w/o Trusted Data filter on test set: 97.84%\n",
      "\n",
      "Evaluating results of Provenance Defense w/ Trusted Data defense...\n",
      "Percent of normal points correctly labeled (True Negative): 100.00%\n",
      "Percent of poison points correctly labeled (True Positive): 100.00%\n",
      "Accuracy of classifier trained with Provenance Defense w/ Trusted Data filter on test set: 97.68%\n"
     ]
    }
   ],
   "source": [
    "evaluate_defense(roni_no_cal, \"RONI w/o Calibration\")\n",
    "evaluate_defense(roni_defense, \"RONI w/ Calibration\")\n",
    "evaluate_defense(prov_defense_no_trust, \"Provenance Defense w/o Trusted Data\")\n",
    "evaluate_defense(prov_defense_trust, \"Provenance Defense w/ Trusted Data\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In [the paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8473440), we show that even with only a limited amount of trusted data, you can still have a very powerful defense that detects bad actors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
